import numpy as np
import torch

# Dump one weight tensor from a PyTorch checkpoint to a .npy file so it can be
# inspected or compared offline without torch.
CHECKPOINT_PATH = "/lgsl_data1/huawei_wz/iter_0000010_HF_FLM2_fix_head_new/pytorch_model.bin"
WEIGHT_KEY = "transformer.h.0.mlp.c_proj.weight"
OUTPUT_PATH = "pt_dense_weight.npy"

# map_location="cpu": .numpy() only accepts CPU tensors, and loading onto CPU
# avoids requiring CUDA if the checkpoint was saved from GPU tensors.
torch_param = torch.load(CHECKPOINT_PATH, map_location="cpu")
# Cast to float32 before .numpy(): NumPy has no bfloat16 dtype, so a reduced-
# precision tensor may not convert directly (stored dtype not verified here).
pt_dense_weight = torch_param[WEIGHT_KEY].to(torch.float32).numpy()
np.save(OUTPUT_PATH, pt_dense_weight)
# NOTE(review): inactive debugging leftovers. `compare_acc`, `ms_weight` and
# `pt_weight` are not imported/defined in this script; wire them up before
# re-enabling.
# pt_dense_x = np.load("/lgsl_data1/huawei_wz/Megatron-FLM-huawei/numpys/selfattn_input_hidden_states.npy")
# compare_acc.compare(ms_weight, pt_weight, save_name_prefix="ms_dense_layer")
